import torch
import torch.nn as nn
import torch.nn.functional as F
from models.PreActRobustmodel import PreActRobustNetwork
from torchprofile import profile_macs
import argparse
import mlconfig
import copy
import os


# Fraction of the total depth / total width multiplier assigned to stages 1-3.
DEPTH_SCALING_RULE = (0.4, 0.4, 0.2)
WIDTH_SCALING_RULE = (0.35, 0.45, 0.2)


def split_across_stages(total, ratios):
    """Split an integer budget across three stages according to ``ratios``.

    Every share starts as ``round(total * ratio)``. Rounding can make the
    shares miss ``total``, so one correction pass follows: while the sum is
    too large, stages 3, 1, 2 (in that order) each give up one unit; while it
    is too small, stages 2, 1, 3 (in that order) each gain one unit.

    Args:
        total: Integer budget to distribute.
        ratios: Three per-stage fractions.

    Returns:
        A list of three integers whose sum equals ``total``.

    Raises:
        ValueError: If ``ratios`` does not contain exactly three values, or
            if the correction pass cannot make the shares sum to ``total``.
    """
    if len(ratios) != 3:
        raise ValueError("expected 3 stage ratios, got {}".format(len(ratios)))

    shares = [round(total * ratio) for ratio in ratios]
    # Overshoot: trim in the order stage 3 -> stage 1 -> stage 2.
    for stage in (2, 0, 1):
        if sum(shares) > total:
            shares[stage] -= 1
    # Undershoot: grow in the order stage 2 -> stage 1 -> stage 3.
    for stage in (1, 0, 2):
        if sum(shares) < total:
            shares[stage] += 1

    if sum(shares) != total:
        raise ValueError("{} != {} ".format(sum(shares), total))
    return shares


def block_settings(block_type):
    """Look up the block hyper-parameters for a ``--block_type`` value.

    Args:
        block_type: Either ``'basic_block'`` or ``'robustresblock'``.

    Returns:
        A tuple ``(cardinality, base_width, scales, block_types,
        se_reduction)`` where ``block_types`` is a three-element list (one
        entry per stage).

    Raises:
        ValueError: If ``block_type`` is not one of the supported values.
    """
    if block_type == 'basic_block':
        return 1, 160, 1, ['basic_block', 'basic_block', 'basic_block'], 64
    if block_type == "robustresblock":
        return 4, 10, 8, ['sebottle2neck', 'sebottle2neck', 'sebottle2neck'], 64
    raise ValueError("Block@{} is not supported yet".format(block_type))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compound Scaling')
    parser.add_argument('--expected_depth', type=int, default=20)
    parser.add_argument('--depth_ratio', type=float, default=0.7, help="0.7 is the best in our paper ")
    parser.add_argument('--block_type', type=str, default='robustresblock')
    parser.add_argument('--save_name', type=str, default='arch_001')
    args = parser.parse_args()

    # Reject inputs that would otherwise fail later with an opaque error
    # (division by zero, negative widths, undefined block settings).
    if args.expected_depth < 1:
        parser.error("--expected_depth must be a positive integer")
    if not 0.0 < args.depth_ratio < 1.0:
        parser.error("--depth_ratio must lie strictly between 0 and 1")
    try:
        cardinality, base_width, scales, block_types, se_reduction = block_settings(args.block_type)
    except ValueError as exc:
        parser.error(str(exc))

    expected_depth = args.expected_depth
    depth_ratio = args.depth_ratio
    # depth_ratio = depth / (depth + width)  =>  width = depth / ratio * (1 - ratio)
    expected_width = round(expected_depth / depth_ratio * (1 - depth_ratio))

    depth = split_across_stages(expected_depth, DEPTH_SCALING_RULE)
    multipliers = split_across_stages(expected_width, WIDTH_SCALING_RULE)
    # Stem has 16 channels; each stage scales its base width (16/32/64).
    channels = [16, 16 * multipliers[0], 32 * multipliers[1], 64 * multipliers[2]]
    expected_depth, expected_width = sum(depth), sum(multipliers)

    # default setting
    drop_rate_config = (0.0, 0.0, 0.0)
    stride_config = [1, 2, 2]
    num_classes = 10
    kernel_size = [3, 3, 3]
    widen_factor = 9
    activations = ('ReLU', 'ReLU', 'ReLU')
    normalizations = ('BatchNorm', 'BatchNorm', 'BatchNorm')

    model = PreActRobustNetwork(
        num_classes=num_classes, channel_configs=channels, depth_configs=depth,
        stride_config=stride_config, stem_stride=1,
        drop_rate_config=drop_rate_config,
        kernel_size_configs=kernel_size,
        zero_init_residual=False,
        block_types=block_types,
        activations=activations,
        normalizations=normalizations,
        is_imagenet=False,
        use_init=True,
        cardinality=cardinality,
        base_width=base_width,
        widen_factor=widen_factor,
        scales=scales,
        se_reduction=se_reduction,
    )
    print(model)
    # Trainable parameters, in millions.
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

    data = torch.rand(1, 3, 32, 32)
    # Plain forward pass first so shape errors surface before profiling.
    model(data)
    # profile_macs result scaled by 1e9 (reported below under the "flops" label).
    flops = profile_macs(model, data) / 1e9
    # Achieved depth ratio after rounding.
    print(expected_depth / (expected_depth + expected_width))
    print('depth@{}-{}-width@{}-{}-channels@{}-block@{}-params = {:.3f}, flops = {:.3f}'.
          format(expected_depth, depth, expected_width, multipliers, channels, block_types[0], param_count, flops))

    # Write the finding setting into config format
    template_config = mlconfig.load("./configs/robustresnets/cifar10/5G_arch_001.yaml")
    config = copy.deepcopy(template_config)
    config['model']['block_types'] = block_types
    config['model']["kernel_size_configs"] = kernel_size
    # NOTE(review): these two are tuples; depending on how mlconfig dumps
    # YAML they may be written with a Python-specific tag -- confirm.
    config['model']["activations"] = activations
    config['model']["normalizations"] = normalizations
    if args.block_type == 'robustresblock':
        config['model']["scales"], config["model"]["base_width"] = scales, base_width
        config['model']['cardinality'], config["model"]["se_reduction"] = cardinality, se_reduction

    config['model']['depth_configs'] = [int(stage_depth) for stage_depth in depth]
    config['model']['channel_configs'] = channels
    config["Params"] = "{:.2f}".format(param_count)
    config["FLOPs"] = "{:.2f}".format(flops)

    output_dir = "./configs/tailored_archs/"
    # Create the target directory up front so a fresh checkout can save.
    os.makedirs(output_dir, exist_ok=True)
    config.save(os.path.join(output_dir, "{}.yaml".format(args.save_name)),
                default_flow_style=False, sort_keys=False, allow_unicode=False)


